{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "from utils import visualization_evaluation,visualization_y\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "outputs": [],
   "source": [
    "# some parameters\n",
    "train_size = 6000\n",
    "\n",
    "blk_sz, sensitivity = 8, 8\n",
    "selected_bands = [127, 201, 202, 294]\n",
    "tree_num = 185\n",
    "pic_row, pic_col= 600, 1024\n",
    "\n",
    "dataset_file = f'./dataset/data_{blk_sz}x{blk_sz}_c{len(selected_bands)}_sen{sensitivity}_test.p'\n",
    "model_file = f'./models/rf_pca_{blk_sz}x{blk_sz}_c{len(selected_bands)}_{tree_num}_1.model'"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# 数据集与样本平衡"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "数据量：  9600\n",
      "x train (6720, 8, 8, 4), y train (6720,)\n",
      "x test (2880, 8, 8, 4), y test (2880,)\n"
     ]
    }
   ],
   "source": [
    "# 读取数据\n",
    "with open(dataset_file, 'rb') as f:\n",
    "    x_list, y_list = pickle.load(f)\n",
    "# 确保数据当中x和y数量对得上\n",
    "assert len(x_list) == len(y_list)\n",
    "print(\"数据量： \", len(x_list))\n",
    "x, y = np.asarray(x_list), np.asarray(y_list, dtype=int)\n",
    "x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=5,\n",
    "                                                    shuffle=True, stratify=y)\n",
    "print(f\"x train {x_train.shape}, y train {y_train.shape}\\n\"\n",
    "      f\"x test {x_test.shape}, y test {y_test.shape}\")"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total (array([   0.,    0.,    0.,    0.,    0., 9600.,    0.]), array([0, 1, 2, 3, 4, 5, 6, 7]), <BarContainer object of 7 artists>) \n",
      "train (array([   0.,    0.,    0.,    0.,    0., 6720.,    0.]), array([0, 1, 2, 3, 4, 5, 6, 7]), <BarContainer object of 7 artists>) \n",
      "test (array([   0.,    0.,    0.,    0.,    0., 2880.,    0.]), array([0, 1, 2, 3, 4, 5, 6, 7]), <BarContainer object of 7 artists>)\n"
     ]
    },
    {
     "data": {
      "text/plain": "<Figure size 864x432 with 3 Axes>",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAtIAAAFmCAYAAABA2X1UAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAiOUlEQVR4nO3df6yc1X3n8fenNiHkhzdmubCur7WmlZsW0AbKlesuUpSWtrhNFNPVsutICVaF1l3WbclupBbnn7R/WGJXbZqiXZDckGC2NNRLEmFlQ1rXSZSNRCEXQusYh8UbKL61i2+TTeP0D7I43/1jDs1gxjZ+7th37sz7JY2eZ77znJnzJDryh7lnzklVIUmSJOns/NBid0CSJElaigzSkiRJUgcGaUmSJKkDg7QkSZLUgUFakiRJ6sAgLUmSJHVwxiCd5GNJjiX5Wl/t4iR7kzzTjiv7Xtue5FCSp5Pc0Fe/Nsn+9tqdSdLqFyb5k1Z/NMnaId+jJCDJW5M82ff4TpL3D3M8SxqOJK9P8liSv0xyIMnvtLrjVRohOdM60kneDnwXuK+qrmq1/wJ8q6ruSHI7sLKqfivJFcAngPXADwN/DvxYVZ1I8hhwG/AXwGeBO6vq4ST/AfgXVfXvk2wGfrmq/u2ZOn7JJZfU2rVrO962NH4ef/zxv6uqqddybZJlwN8APwVsY0jj+VSf53iVXulM47WF3TdW1XeTXAB8md6Y+1ec4/EKjlnpZKcas8vP1LCqvjTgW+JNwDva+S7gi8BvtfoDVfUi8GySQ8D6JM8BK6rqEYAk9wE3Ag+3Nr/d3utB4L8mSZ0h4a9du5bZ2dkzdV+aGEn++iwuvx74P1X110mGOZ4HcrxKr3Sm8dr+Dfxue3pBexTD/ff3lByz0iudasx2nSN9WVUdBWjHS1t9NXC477q5Vlvdzk+uv6JNVb0E/D3wTwd9aJKtSWaTzM7Pz3fsuiRgM71vr2C44/kfOV6lhUmyLMmTwDFgb1U9yjkar+3zHLPSWRr2jw0Hzbuq09RP1+bVxaqdVTVTVTNTU6/pL9iSTpLkdcC7gf9xpksH1M40nn9QcLxKC1JVJ6rqamCa3rfLV53m8gWN1/Z5jlnpLHUN0i8kWQXQjsdafQ5Y03fdNHCk1acH1F/RJsly4J8A3+rYL0ln9ovAE1X1Qns+zPEsaciq6tv0pnBsxPEqjZSuQXoPsKWdbwEe6qtvbitxXA6sAx5rf346nmRD+wHFzSe1efm9/jXw+TPNj5a0IO/hB9M6YLjjWdIQJJlK8pZ2fhHwc8DXcbxKI+WMPzZM8gl6P2y4JMkc8CHgDmB3kluA54GbAKrqQJLdwFPAS8C2qjrR3upW4F7gIno/cnj5hw73AP+9/TDiW/Tmbko6B5K8Afh54Ff7ysMcz5KGYxWwq62w80PA7qr6TJJHcLxKI+OMy9+NqpmZmfIXxdIPJHm8qmYWux+DOF6lVxrl8QqOWelkpxqz7mwoSZIkdWCQliRJkjowSEuSJEkdGKQlSZKkDgzSkiRJUgcGaUmSJKkDg7QkSZLUwRk3ZJGWqrW3/89z/hnP3fHOc/4Z0iQ4H+MVHLPSsPhvbI/fSEuSJEkdGKQlSZKkDgzSkiRJUgcGaUmSJKkDg7QkSZLUgUFakiRJ6sAgLUmSJHVgkJYkSZI6MEhLkiRJHRikJUmSpA4M0pIkSVIHBmlJkiSpA4O0JEmS1IFBWpIkSerAIC1JkiR1YJCWJEmSOjBIS5IkSR0YpCVJkqQODNKSJElSBwZpSZIkqQODtCRJktSBQVqSJEnqwCAtSZIkdWCQliRJkjowSEuSJEkdGKQlSZKkDgzSkiRJUgcGaUmSJKkDg7QkSZLUgUFamiBJ3pLkwSRfT3IwyU8nuTjJ3iTPtOPKvuu3JzmU5OkkN/TVr02yv712Z5Iszh1JkrR4DNLSZPkD4HNV9ePA24CDwO3AvqpaB+xrz0lyBbAZuBLYCNyVZFl7n7uBrcC69th4Pm9CkqRRYJCWJkSSFcDbgXsAqup7VfVtYBOwq122C7ixnW8CHqiqF6vqWeAQsD7JKmBFVT1SVQXc19dGkqSJYZCWJsePAPPAx5N8NclHk7wRuKyqjgK046Xt+tXA4b72c622up2fXH+FJFuTzCaZnZ+fH/7dSJK0yAzS0uRYDvwkcHdVXQP8A20axykMmvdcp6m/slC1s6pmqmpmamqqS38lSRppBmlpcswBc1X1aHv+IL1g/UKbrkE7Huu7fk1f+2ngSKtPD6hLkjRRDNLShKiqvwUOJ3lrK10PPAXsAba02hbgoXa+B9ic5MIkl9P7UeFjbfrH8SQb2modN/e1kSRpYixf7A5IOq9+Hbg/yeuAbwC/Qu8/qHcnuQV4HrgJoKoOJNlNL2y/BGyrqhPtfW4F7gUuAh5uD0mSJopBWpogVfUkMDPgpetPcf0OYMeA+ixw1VA7J+kfJVlDb0WcfwZ8H9hZVX+Q5LeBf0fvh8MAH6yqz7Y224FbgBPAb1TVn7b6tfzgP3w/C9zWVtyRtEAGaUmSRs9LwAeq6okkbwYeT7K3vfb7VfW7/ReftO77DwN/nuTH2l+RXl73/S/oBemN+FckaSicIy1J0oipqqNV9UQ7P05v86RXLTPZx3XfpUVgkJYkaYQlWQtcA7y84s6vJfmrJB9LsrLVFrTue/sc136XzpJBWpKkEZXkTcAngfdX1XfoTdP4UeBq4Cjwey9fOqD5a173HVz7XerCIC1J0ghKcgG9EH1/VX0KoKpeqKoTVfV94A+B9e1y132XFoFBWpKkEdPWaL8HOFhVH+6rr+q77JeBr7Vz132XFoGrdkiSNHquA94H7E/yZKt9EHhPkqvpTc94DvhVcN13abEYpCVJGjFV9WUGz2/+7GnauO67dJ45tUOSJEnqwCAtSZIkdbCgIJ3kPyY5kORrST6R5PVJLk6yN8kz7biy7/rtSQ4leTrJDX31a5Psb6/d2X4QIUmSJI2szkE6yWrgN4CZqroKWEZve9LbgX1VtQ7Y156fvH3pRuCuJMva2728fem69tjYtV+SJEnS+bDQqR3LgYuSLAfeQG9tyk3Arvb6Ln6wFanbl0qSJGlsdA7SVfU3wO8Cz9PbXenvq+rPgMvaupW046WtiduXSpIkaWwsZGrHSnrfMl8O/DDwxiTvPV2TATW3L5UkSdKStJCpHT8HPFtV81X1/4BPAf8SeOHlnZfa8Vi73u1LJUmSNDYWEqSfBzYkeUNbZeN64CC9bUq3tGu28IOtSN2+VJIkSWOj886GVfVokgeBJ+htR/pVYCfwJmB3klvohe2b2vVuXypJkqSxsaAtwqvqQ8CHTiq/SO/b6UHXu32pJEmSxoI7G0qSJEkdGKQlSZKkDgzSkiRJUgcGaUmSJKkDg7QkSZLUgUFakiRJ6sAgLUmSJHVgkJYkSZI6MEhLkiRJHRikJUmSpA4M0pIkSVIHBmlJkiSpA4O0JEmS1IFBWpIkSerAIC1JkiR1YJCWJEmSOjBIS5IkSR0YpCVJkqQODNKSJElSBwZpaYIkeS7J/iRPJplttYuT7E3yTDuu7Lt+e5JDSZ5OckNf/dr2PoeS3Jkki3E/kiQtJoO0NHl+pqqurqqZ9vx2YF9VrQP2teckuQLYDFwJbATuSrKstbkb2Aqsa4+N57H/kiSNBIO0pE3Arna+C7ixr/5AVb1YVc8Ch4D1SVYBK6rqkaoq4L6+NpIkTQyDtDRZCvizJI8n2dpql1XVUYB2vLTVVwOH+9rOtdrqdn5y/RWSbE0ym2R2fn5+yLchSdLiW77YHZB0Xl1XVUeSXArsTfL101w7aN5znab+ykLVTmAnwMzMzKtelyRpqfMbaWmCVNWRdjwGfBpYD7zQpmvQjsfa5XPAmr7m08CRVp8eUJckaaIYpKUJkeSNSd788jnwC8DXgD3AlnbZFuChdr4H2JzkwiSX0/tR4WNt+sfxJBvaah0397WRJGliOLVDmhyXAZ9uK9UtB/64qj6X5CvA7iS3AM8DNwFU1YEku4GngJeAbVV1or3XrcC9wEXAw+0hSdJEMUhLE6KqvgG8bUD9m8D1p2izA9gxoD4LXDXsPkqStJQ4tUOSJEnqwCAtSZIkdWCQliRJkjowSEuSJEkdGKQlSZKkDgzSkiRJUgcGaUmSRkySNUm+kORgkgNJbmv1i5PsTfJMO67sa7M9yaEkTye5oa9+bZL97bU720ZKkobAIC1J0uh5CfhAVf0EsAHYluQK4HZgX1WtA/a157TXNgNXAhuBu5Isa+91N7CV3u6k69rrkobAIC1J0oipqqNV9UQ7Pw4cBFYDm4Bd7bJdwI3tfBPwQFW9WFXPAoeA9UlWASuq6pGqKuC+vjaSFsggLUnSCEuyFrgGeBS4rKqOQi9sA5e2y1YDh/uazbXa6nZ+cn3Q52xNMptkdn5+fqj3II0rg7QkSSMqyZuATwLvr6rvnO7SAbU6Tf3VxaqdVTVTVTNTU1Nn31lpAhmkJUkaQUkuoBei76+qT7XyC226Bu14rNXngDV9zaeBI60+PaAuaQgM0pIkjZi2ssY9wMGq+nDfS3uALe18C/BQX31zkguTXE7vR4WPtekfx5NsaO95c18bSQu0fLE7IEmSXuU64H3A/iRPttoHgTuA3UluAZ4HbgKoqgNJdgNP0VvxY1tVnWjtbgXuBS4CHm4PSUNgkJYkacRU1ZcZPL8Z4PpTtNkB7BhQnwWuGl7vJL3MqR2SJElSBwZpSZIkqQODtCRJktSBQVqSJEnqwCAtSZIkdWCQliRJkjowSEuSJEkdGKQlSZKkDgzSkiRJUgcGaUmSJKkDg7QkSZLUgUFakiRJ6sAgLUmSJHWwoCCd5C1JHkzy9SQHk/x0kouT7E3yTDuu7Lt+e5JDSZ5OckNf/dok+9trdybJQvolSZIknWsL/Ub6D4DPVdWPA28DDgK3A/uqah2wrz0nyRXAZuBKYCNwV5Jl7X3uBrYC69pj4wL7JUmSJJ1TnYN0khXA24F7AKrqe1X1bWATsKtdtgu4sZ1vAh6oqher6lngELA+ySpgRVU9UlUF3NfXRpIkSRpJC/lG+keAeeDjSb6a5KNJ3ghcVlVHAdrx0nb9auBwX/u5Vlvdzk+uv0qSrUlmk8zOz88voOuSJEnSwiwkSC8HfhK4u6quAf6BNo3jFAbNe67T1F9drNpZVTNVNTM1NXW2/ZUkSZKGZiFBeg6Yq6pH2/MH6QXrF9p0DdrxWN/1a/raTwNHWn16QF2SJEkaWZ2DdFX9LXA4yVtb6XrgKWAPsKXVtgAPtfM9wOYkFya5nN6PCh9r0z+OJ9nQVuu4ua+NJEmSNJKWL7D9rwP3J3kd8A3gV+iF891JbgGeB24CqKoDSXbTC9svAduq6kR7n1uBe4GLgIfbQ5IkSRpZCwrSVfUkMDPgpetPcf0OYMeA+ixw1UL6IkmSJJ1P7mwoTZAky9oqO59pz91ASZKkjgzS0mS5jd7GSS9zAyVJkjoySEsTIsk08E7go31lN1CSJKkjg7Q0OT4C/Cbw/b6aGyhJktSRQVqaAEneBRyrqsdfa5MBNTdQkiSpz0KXv5O0NFwHvDvJLwGvB1Yk+SPaBkpVddQNlCRJOjt+Iy1NgKraXlXTVbWW3o8IP19V78UNlCRJ6sxvpKXJdgduoCRJUicGaWnCVNUXgS+282/iBkqSJHXi1A5JkiSpA4O0JEmS1IFBWpIkSerAIC1JkiR1YJCWJEmSOjBIS5IkSR0YpCVJkqQODNKSJElSBwZpSZIkqQODtCRJktSBQVqSJEnqwCAtSZIkdWCQliRJkjowSEuSJEkdGKQlSRoxST6W5FiSr/XVfjvJ3yR5sj1+qe+17UkOJXk6yQ199WuT7G+v3Zkk5/tepHFmkJYkafTcC2wcUP/9qrq6PT4LkOQKYDNwZWtzV5Jl7fq7ga3AuvYY9J6SOjJIS5I0YqrqS8C3XuPlm4AHqurFqnoWOASsT7IKWFFVj1RVAfcBN56TDksTyiAtSdLS8WtJ/qpN/VjZaquBw33XzLXa6nZ+cn2gJFuTzCaZnZ+fH3a/pbFkkJYkaWm4G/hR4GrgKPB7rT5o3nOdpj5QVe2sqpmqmpmamlpgV6XJYJCWJGkJqKoXqupEVX0f+ENgfXtpDljTd+k0cKTVpwfUJQ2JQVqSpCWgzXl+2S8DL6/osQfYnOTCJJfT+1HhY1V1FDieZENbreNm4KHz2mlpzC1f7A5IkqRXSvIJ4B3AJUnmgA8B70hyNb3pGc8BvwpQVQeS7AaeAl4CtlXVifZWt9JbAeQi4OH2kDQkBmlJkkZMVb1nQPme01y/A9gxoD4LXDXErknq49QOSZIkqQODtCRJktSBQVqSJEnqwCAtSZIkdWCQliRJkjowSEuSJEkdGKQlSZKkDgzSkiRJUgcGaUmSJKkDg7QkSZLUgUFakiRJ6sAgLUmSJHVgkJYkSZI6MEhLkiRJHRikpQmR5PVJHkvyl0kOJPmdVr84yd4kz7Tjyr4225McSvJ0khv66tcm2d9euzNJFuOeJElaTAZpaXK8CPxsVb0NuBrYmGQDcDuwr6rWAfvac5JcAWwGrgQ2AnclWdbe625gK7CuPTaex/uQJGkkGKSlCVE9321PL2iPAjYBu1p9F3BjO98EPFBVL1bVs8AhYH2SVcCKqnqkqgq4r6+NJEkTwyAtTZAky5I8CRwD9lbVo8BlVXUUoB0vbZevBg73NZ9rtdXt/OT6yZ+1Nclsktn5+fmh34skSYvNIC1NkKo6UVVXA9P0vl2+6jSXD5r3XKepn/xZO6tqpqpmpqamOvVXkqRRZpCWJlBVfRv4Ir25zS+06Rq047F22Rywpq/ZNHCk1acH1CVJmigGaWlCJJlK8pZ2fhHwc8DXgT3AlnbZFuChdr4H2JzkwiSX0/tR4WNt+sfxJBvaah0397WRJGliLF/sDkg6b1YBu9rKGz8E7K6qzyR5BNid5BbgeeAmgKo6kGQ38BTwErCtqk6097oVuBe4CHi4PSRJmigGaWlCVNVfAdcMqH8TuP4UbXYAOwbUZ4HTza+WJGnsLXhqR1sF4KtJPtOeu7mDJEmSxt4w5kjfBhzse+7mDpIkSRp7CwrSSaaBdwIf7Su7uYMkSZLG3kK/kf4I8JvA9/tq52RzB3CDB0mSJI2OzkE6ybuAY1X1+GttMqD2mjd3ADd4kCRJ0uhYyKod1wHvTvJLwOuBFUn+iLa5Q1UddXMHSZIkjavO30hX1faqmq6qtfR+RPj5qnovbu4gSZKkCXAu1pG+Azd3kCRJ0pgbSpCuqi8CX2znbu4gSZKksTeMdaQlSZKkiWOQliRJkjowSEuSJEkdGKQlSZKkDgzSkiRJUgcGaUmSJKkDg7QkSZLUgUFakiRJ6sAgLUmSJHVgkJYkSZI6MEhLkiRJHRikJUmSpA4M0pIkSVIHBmlJkkZMko8lOZbka321i5PsTfJMO67se217kkNJnk5yQ1/92iT722t3Jsn5vhdpnBmkJUkaPfcCG0+q3Q7sq6p1wL72nCRXAJuBK1ubu5Isa23uBrYC69rj5PeUtAAGaUmSRkxVfQn41knlTcCudr4LuLGv/kBVvVhVzwKHgPVJVgErquqRqirgvr42kobAIC1J0tJwWVUdBWjHS1t9NXC477q5Vlvdzk+uD5Rka5LZJLPz8/ND7bg0rgzSkiQtbYPmPddp6gNV1c6qmqmqmampqaF1ThpnBmlJkpaGF9p0DdrxWKvPAWv6rpsGjrT69IC6pCExSEuStDTsAba08y3AQ331zUkuTHI5vR8VPtamfxxPsqGt1nFzXxtJQ7B8sTsgSZJeKckngHcAlySZAz4E3AHsTnIL8DxwE0BVHUiyG3gKeAnYVlUn2lvdSm8FkIuAh9tD0pAYpCVJGjFV9Z5TvHT9Ka7fAewYUJ8Frhpi1yT1cWqHJEmS1IFBWpIkSerAIC1JkiR1YJCWJEmSOjBIS5IkSR0YpCVJkqQODNLShEiyJskXkhxMciDJba1+cZK9SZ5px5V9bbYnOZTk6SQ39NWvTbK/vXZn2+xBkqSJYpCWJsdLwAeq6ieADcC2JFcAtwP7qmodsK89p722GbgS2AjclWRZe6+7ga30dlBb116XJGmiGKSlCVFVR6vqiXZ+HDgIrAY2AbvaZbuAG9v5JuCBqnqxqp4FDgHrk6wCVlTVI1VVwH19bSRJmhgGaWkCJVkLXAM8ClxWVUehF7aBS9tlq4HDfc3mWm11Oz+5fvJnbE0ym2R2fn5+6PcgSdJiM0hLEybJm4BPAu+vqu+c7tIBtTpN/ZWFqp1VNVNVM1NTU906K0nSCDNISxMkyQX0QvT9VfWpVn6hTdegHY+1+hywpq/5NHCk1acH1CVJmigGaWlCtJU17gEOVtWH+17aA2xp51uAh/rqm5NcmORyej8qfKxN/zieZEN7z5v72kiSNDGWL3YHJJ031wHvA/YnebLVPgjcAexOcgvwPHATQFUdSLIbeIreih/bqupEa3crcC9wEfBwe0iSNFEM0tKEqKovM3h+M8D1p2izA9gxoD4LXDW83kmStPQ4tUOSJEnqwCAtSZIkdWCQliRJkjowSEuSJEkdGKQlSZKkDgzSkiRJUgcGaUmSJKkDg7QkSZLUgUFakiRJ6sAgLUmSJHVgkJYkSZI6MEhLkiRJHRikJUmSpA4M0pIkSVIHBmlJkiSpA4O0JEmS1IFBWpIkSerAIC1JkiR1YJCWJEmSOugcpJOsSfKFJAeTHEhyW6tfnGRvkmfacWVfm+1JDiV5OskNffVrk+xvr92ZJAu7LUmSJOncWsg30i8BH6iqnwA2ANuSXAHcDuyrqnXAvvac9tpm4EpgI3BXkmXtve4GtgLr2mPjAvolSZIknXOdg3RVHa2qJ9r5ceAgsBrYBOxql+0Cbmznm4AHqurFqnoWOASsT7IKWFFVj1RVAff1tZEkSZJG0lDmSCdZC1wDPApcVlVHoRe2gUvbZauBw33N5lptdTs/uT7oc7YmmU0yOz8/P4yuS5IkSZ0sOEgneRPwSeD9VfWd0106oFanqb+6WLWzqmaqamZqaursOytJkiQNyYKCdJIL6IXo+6vqU638QpuuQTsea/U5YE1f82ngSKtPD6hLkiRJI2shq3YEuAc4WFUf7ntpD7ClnW8BHuqrb05yYZLL6f2o8LE2/eN4kg3tPW/uayNJkiSNpOULaHsd8D5gf5InW+2DwB3A7iS3AM8DNwFU1YEku4Gn6K34sa2qTrR2twL3AhcBD7eHJEmSNLI6B+mq+jKD5zcDXH+KNjuAHQPqs8BVXfsiSZIknW/ubChJkiR1YJCWJEmSOjBIS5K0hCR5Lsn+JE8mmW21i5PsTfJMO67su357kkNJnk5yw+L1XBo/BmlJkpaen6mqq6tqpj2/HdhXVeuAfe05Sa4ANgNXAhuBu5IsW4wOS+PIIC1J0tK3CdjVzncBN/bVH6iqF6vqWeAQsP78d08aTwZpSZKWlgL+LMnjSba22mVtXwba8dJWXw0c7ms712qvkmRrktkks/Pz8+eo69J4Wcg60pIk6fy7rqqOJLkU2Jvk66e5dtAytTXowqraCewEmJmZGXiNpFfyG2lJkpaQqjrSjseAT9ObqvFCklUA7XisXT4HrOlrPg0cOX+9lcabQVqSpCUiyRuTvPnlc+AXgK8Be4At7bItwEPtfA+wOcmFSS4H1gGPnd9eS+PLqR2SJC0dlwGfTgK9f8P/uKo+l+QrwO4ktwDPAzcBVNWBJLuBp4CXgG1VdWJxui6NH4O0NCGSfAx4F3Csqq5qtYuBPwHWAs8B/6aq/m97bTtwC3AC+I2q+tNWvxa4F7gI+CxwW1U5n1I6D6rqG8DbBtS/CVx/ijY7gB3nuGvSRHJqhzQ57qW3jmy/LmvP3g1spfcn4nUD3lOSpIlgkJYmRFV9CfjWSeWzWnu2/YhpRVU90r6Fvq+vjSRJE8UgLU22s117dnU7P7n+Kq5JK0kadwZpSYOcau3Zs1qTtqpmqmpmampqqJ2TJGkUGKSlyXa2a8/OtfOT65IkTRyDtDTZzmrt2Tb943iSDemtv3VzXxtJkiaKy99JEyLJJ4B3AJckmQM+BNzB2a89eys/WP7u4faQJGniGKSlCVFV7znFS2e19mxVzQJXDbFrkiQtSU7tkCRJkjowSEuSJEkdGKQlSZKkDgzSkiRJUgcGaUmSJKkDg7QkSZLUgUFakiRJ6sAgLUmSJHVgkJYkSZI6MEhLkiRJHRikJUmSpA4M0pIkSVIHBmlJkiSpA4O0JEmS1IFBWpIkSerAIC1JkiR1YJCWJEmSOjBIS5IkSR0YpCVJkqQODNKSJElSBwZpSZIkqQODtCRJktSBQVqSJEnqwCAtSZIkdWCQliRJkjowSEuSJEkdGKQlSZKkDgzSkiRJUgcGaUmSJKkDg7QkSZLUgUFakiRJ6sAgLUmSJHVgkJYkSZI6MEhLkiRJHRikJUmSpA4M0pIkSVIHBmlJkiSpg5EJ0kk2Jnk6yaEkty92fySdnmNWWjocr9K5MRJBOsky4L8BvwhcAbwnyRWL2ytJp+KYlZYOx6t07oxEkAbWA4eq6htV9T3gAWDTIvdJ0qk5ZqWlw/EqnSPLF7sDzWrgcN/zOeCnTr4oyVZga3v63SRPn+F9LwH+big9XHzeywjKfx6pe/nn5/GzzjhmJ3y8wnjdz9jcywiN2ZEarzDxY9Z7GUEjNF7hFGN2VIJ0BtTqVYWqncDO1/ymyWxVzSykY6PCexlN43QvZ+mMY3aSxyuM1/14L0ue/8aegfcympbCvYzK1I45YE3f82ngyCL1RdKZOWalpcPxKp0joxKkvwKsS3J5ktcBm4E9i9wnSafmmJWWDserdI6MxNSOqnopya8BfwosAz5WVQeG8Nav+U9US4D3MprG6V5es3M0Zsftf8txuh/vZQnz39jXxHsZTSN/L6l61TQpSZIkSWcwKlM7JEmSpCXFIC1JkiR1MJZBepy2Qk2yJskXkhxMciDJbYvdp4VKsizJV5N8ZrH7shBJ3pLkwSRfb////PRi92mpGpcx63gdXY7X4XG8jq5xGa+wdMbs2M2Rbluh/m/g5+kt+fMV4D1V9dSidqyjJKuAVVX1RJI3A48DNy7V+wFI8p+AGWBFVb1rsfvTVZJdwP+qqo+2X8K/oaq+vcjdWnLGacw6XkeX43U4HK+jbVzGKyydMTuO30iP1VaoVXW0qp5o58eBg/R2qVqSkkwD7wQ+uth9WYgkK4C3A/cAVNX3RnGALxFjM2Ydr6PJ8TpUjtcRNS7jFZbWmB3HID1oK9QlOzD6JVkLXAM8ushdWYiPAL8JfH+R+7FQPwLMAx9vf0b7aJI3LnanlqixHLOO15HieB0ex+vo+gjjMV5hCY3ZcQzSr2kr1KUmyZuATwLvr6rvLHZ/ukjyLuBYVT2+2H0ZguXATwJ3V9U1wD8AS3au4CIbuzHreB05jtfhcbyOoDEbr7CExuw4Bumx2wo1yQX0Bvn9VfWpxe7PAlwHvDvJc/T+HPizSf5ocbvU2RwwV1Uvf3vxIL1Br7M3VmPW8TqSHK/D43gdTeM0XmEJjdlxDNJjtRVqktCbI3Swqj682P1ZiKraXlXTVbWW3v8vn6+q9y5ytzqpqr8FDid5aytdDyzZH6gssrEZs47X0eR4HSrH6wgap/EKS2vMjsQW4cN0DrdCXSzXAe8D9id5stU+WFWfXbwuqfl14P72j8k3gF9Z5P4sSWM2Zh2vo8vxOgSOV51HS2LMjt3yd5IkSdL5MI5TOyRJkqRzziAtSZIkdWCQliRJkjowSEuSJEkdGKQlSZKkDgzSkiRJUgcGaUmSJKmD/w96Ngq/+OO+CgAAAABJRU5ErkJggg==\n"
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "fig, axs = plt.subplots(1, 3, figsize=(12, 6))\n",
    "hist_res_total = axs[0].hist(y,[0, 1, 2, 3, 4, 5, 6, 7], align='mid')\n",
    "hist_res_train = axs[1].hist(y_train,[0, 1, 2, 3, 4, 5, 6, 7], align='mid')\n",
    "hist_res_test = axs[2].hist(y_test,[0, 1, 2, 3, 4, 5, 6, 7], align='mid')\n",
    "print(f'total {hist_res_total} \\n'\n",
    "      f'train {hist_res_train} \\n'\n",
    "      f'test {hist_res_test}')\n",
    "plt.show()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "The target 'y' needs to have more than 1 class. Got 1 class instead",
     "output_type": "error",
     "traceback": [
      "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[1;31mValueError\u001B[0m                                Traceback (most recent call last)",
      "Input \u001B[1;32mIn [33]\u001B[0m, in \u001B[0;36m<cell line: 6>\u001B[1;34m()\u001B[0m\n\u001B[0;32m      4\u001B[0m x_train_shape \u001B[38;5;241m=\u001B[39m x_train\u001B[38;5;241m.\u001B[39mshape\n\u001B[0;32m      5\u001B[0m x_train \u001B[38;5;241m=\u001B[39m x_train\u001B[38;5;241m.\u001B[39mreshape((x_train\u001B[38;5;241m.\u001B[39mshape[\u001B[38;5;241m0\u001B[39m], \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m1\u001B[39m))\n\u001B[1;32m----> 6\u001B[0m x_resampled, y_resampled \u001B[38;5;241m=\u001B[39m \u001B[43mros\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mfit_resample\u001B[49m\u001B[43m(\u001B[49m\u001B[43mx_train\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43my_train\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m      7\u001B[0m \u001B[38;5;66;03m# 画图\u001B[39;00m\n\u001B[0;32m      8\u001B[0m fig, axs \u001B[38;5;241m=\u001B[39m plt\u001B[38;5;241m.\u001B[39msubplots(figsize\u001B[38;5;241m=\u001B[39m(\u001B[38;5;241m10\u001B[39m, \u001B[38;5;241m6\u001B[39m))\n",
      "File \u001B[1;32m~\\miniconda3\\envs\\deepo\\lib\\site-packages\\imblearn\\base.py:79\u001B[0m, in \u001B[0;36mSamplerMixin.fit_resample\u001B[1;34m(self, X, y)\u001B[0m\n\u001B[0;32m     76\u001B[0m arrays_transformer \u001B[38;5;241m=\u001B[39m ArraysTransformer(X, y)\n\u001B[0;32m     77\u001B[0m X, y, binarize_y \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_check_X_y(X, y)\n\u001B[1;32m---> 79\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39msampling_strategy_ \u001B[38;5;241m=\u001B[39m \u001B[43mcheck_sampling_strategy\u001B[49m\u001B[43m(\u001B[49m\n\u001B[0;32m     80\u001B[0m \u001B[43m    \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43msampling_strategy\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43my\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_sampling_type\u001B[49m\n\u001B[0;32m     81\u001B[0m \u001B[43m\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m     83\u001B[0m output \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_fit_resample(X, y)\n\u001B[0;32m     85\u001B[0m y_ \u001B[38;5;241m=\u001B[39m (\n\u001B[0;32m     86\u001B[0m     label_binarize(output[\u001B[38;5;241m1\u001B[39m], classes\u001B[38;5;241m=\u001B[39mnp\u001B[38;5;241m.\u001B[39munique(y)) \u001B[38;5;28;01mif\u001B[39;00m binarize_y \u001B[38;5;28;01melse\u001B[39;00m output[\u001B[38;5;241m1\u001B[39m]\n\u001B[0;32m     87\u001B[0m )\n",
      "File \u001B[1;32m~\\miniconda3\\envs\\deepo\\lib\\site-packages\\imblearn\\utils\\_validation.py:500\u001B[0m, in \u001B[0;36mcheck_sampling_strategy\u001B[1;34m(sampling_strategy, y, sampling_type, **kwargs)\u001B[0m\n\u001B[0;32m    494\u001B[0m     \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mValueError\u001B[39;00m(\n\u001B[0;32m    495\u001B[0m         \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124msampling_type\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124m should be one of \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mSAMPLING_KIND\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m. \u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m    496\u001B[0m         \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mGot \u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;132;01m{\u001B[39;00msampling_type\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m instead.\u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m    497\u001B[0m     )\n\u001B[0;32m    499\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m np\u001B[38;5;241m.\u001B[39munique(y)\u001B[38;5;241m.\u001B[39msize \u001B[38;5;241m<\u001B[39m\u001B[38;5;241m=\u001B[39m \u001B[38;5;241m1\u001B[39m:\n\u001B[1;32m--> 500\u001B[0m     \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mValueError\u001B[39;00m(\n\u001B[0;32m    501\u001B[0m         \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mThe target \u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124my\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124m needs to have more than 1 class. \u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m    502\u001B[0m         \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mGot \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mnp\u001B[38;5;241m.\u001B[39munique(y)\u001B[38;5;241m.\u001B[39msize\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m class instead\u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m    503\u001B[0m     )\n\u001B[0;32m    505\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m sampling_type \u001B[38;5;129;01min\u001B[39;00m (\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mensemble\u001B[39m\u001B[38;5;124m\"\u001B[39m, \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mbypass\u001B[39m\u001B[38;5;124m\"\u001B[39m):\n\u001B[0;32m    506\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m sampling_strategy\n",
      "\u001B[1;31mValueError\u001B[0m: The target 'y' needs to have more than 1 class. Got 1 class instead"
     ]
    }
   ],
   "source": [
    "# 发现样本实在是不平衡\n",
    "from imblearn.over_sampling import RandomOverSampler\n",
    "ros = RandomOverSampler(random_state=0)\n",
    "x_train_shape = x_train.shape\n",
    "x_train = x_train.reshape((x_train.shape[0], -1))\n",
    "x_resampled, y_resampled = ros.fit_resample(x_train, y_train)\n",
    "# 画图\n",
    "fig, axs = plt.subplots(figsize=(10, 6))\n",
    "hist_res_train = axs.hist(y_resampled, [0, 1, 2, 3, 4, 5, 6, 7], align='mid')\n",
    "print(f'train {hist_res_train} \\n')\n",
    "plt.show()\n",
    "# 抽样\n",
    "x_train, _, y_train, _ =  train_test_split(x_resampled, y_resampled, train_size=train_size, random_state=0,\n",
    "                                           shuffle=True, stratify=y_resampled)\n",
    "# 画图\n",
    "fig, axs = plt.subplots(figsize=(10, 6))\n",
    "hist_res_train = axs.hist(y_train, [0, 1, 2, 3, 4, 5, 6 ,7], align='mid')\n",
    "print(f'train {hist_res_train} \\n')\n",
    "plt.show()\n",
    "x_train = x_train.reshape(x_train.shape[0], x_train_shape[1], x_train_shape[2], x_train_shape[3])\n",
    "print(len(x_train))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}